# Copyright (c) OpenMMLab. All rights reserved.
import os
from tqdm import tqdm
import sys
import time, gc

import asyncio
from argparse import ArgumentParser

from mmdet.apis import (async_inference_detector, inference_detector,
                        init_detector, show_result_pyplot)


def parse_args():
    """Parse the command-line options of the inference demo.

    Positional arguments, in order: ``imgList``, ``config``, ``checkpoint``
    and ``out_file_path``.  Optional flags: ``--device`` (default
    ``'cuda:0'``), ``--score-thr`` (float, default ``0.3``) and
    ``--async-test`` (boolean switch, off by default).

    Returns:
        argparse.Namespace: The parsed options taken from ``sys.argv``.
    """
    parser = ArgumentParser()

    # Positional arguments are registered in the order they must appear on
    # the command line.
    positionals = (
        ('imgList', 'Image file'),
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
        ('out_file_path', 'output file dir path'),
    )
    for dest, description in positionals:
        parser.add_argument(dest, help=description)

    # Optional flags.
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    parser.add_argument(
        '--async-test',
        action='store_true',
        help='whether to set async options for async inference.')

    return parser.parse_args()


def main(args):
    """Run the detector on every image named in the list file.

    ``args.imgList`` is opened as a text file and read line by line; each
    non-blank line is used as the path of one image.  Every image is run
    through ``inference_detector`` and the result is handed to
    ``model.show_result`` with ``out_file`` set to
    ``<args.out_file_path>/<image base name without extension>.jpg``.

    Args:
        args (argparse.Namespace): Parsed options.  Reads ``config``,
            ``checkpoint``, ``device``, ``imgList``, ``out_file_path`` and
            ``score_thr``.
    """
    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)

    # Read the whole list first so the file handle is released before the
    # inference loop starts.  Lines are stripped once here, and blank lines
    # (e.g. a trailing newline at the end of the file) are dropped so they
    # are never passed to the detector as an empty path.
    with open(args.imgList, 'r') as fr:
        stripped_lines = (line.strip() for line in fr)
        img_paths = [path for path in stripped_lines if path]

    for img_path in tqdm(img_paths):
        # splitext handles any extension (.png, .JPG, ...); the output is
        # always written with a .jpg suffix.
        stem = os.path.splitext(os.path.basename(img_path))[0]
        out_file_path = os.path.join(args.out_file_path, stem + '.jpg')

        result = inference_detector(model, img_path)

        # Honour the --score-thr option instead of a hard-coded threshold
        # (its default is 0.3, the value previously hard-coded here).
        model.show_result(
            img_path,
            result,
            score_thr=args.score_thr,
            bbox_color=(72, 101, 241),
            text_color=(72, 101, 241),
            mask_color=None,
            thickness=2,
            font_size=13,
            win_name='',
            show=False,
            wait_time=0,
            out_file=out_file_path)



async def async_main(args):
    """Asynchronous variant of the demo.

    Builds the detector, schedules ``async_inference_detector`` on
    ``args.imgList`` as an asyncio task, waits for it, and displays the
    first gathered result with ``show_result_pyplot`` using
    ``args.score_thr`` as the score threshold.

    Args:
        args (argparse.Namespace): Parsed options.  Reads ``config``,
            ``checkpoint``, ``device``, ``imgList`` and ``score_thr``.
    """
    # build the model from a config file and a checkpoint file
    detector = init_detector(
        args.config, args.checkpoint, device=args.device)

    # schedule the inference coroutine and wait for it to finish
    pending = asyncio.create_task(
        async_inference_detector(detector, args.imgList))
    outputs = await asyncio.gather(pending)

    # gather() yields one entry per awaitable; only one was submitted
    detections = outputs[0]

    # show the results
    show_result_pyplot(
        detector, args.imgList, detections, score_thr=args.score_thr)


if __name__ == '__main__':
    # Script entry point: parse the command line, then run either the
    # synchronous list-based pipeline or the asyncio-based one.
    args = parse_args()
    if not args.async_test:
        main(args)
    else:
        asyncio.run(async_main(args))


    """


    import os
    base_dir = "/home/musk/video_bak/view/frames"  # root directory; the index (list) files are stored under it

    file_count = 1
    img_count = 0
    names_list = []

    img_dir_name = "plane_frames"

    img_dir = os.path.join(base_dir, img_dir_name)
    for name in os.listdir(img_dir):
        names_list.append(img_dir + "/" + name + "\n")

        img_count = img_count + 1

        list_dir = os.path.join(base_dir, img_dir_name+"_list")
        if not os.path.exists(list_dir):
            os.mkdir(list_dir)

        if img_count >= 100:
            # names_list[-1].strip()
            with open(os.path.join(list_dir, "i_%s.txt" % file_count), 'w') as fw:
                fw.writelines(names_list)

            img_count = 0
            file_count = file_count + 1
            names_list = []

    with open(os.path.join(list_dir, "i_%s.txt" % file_count), 'w') as fw:
        fw.writelines(names_list)

    """